Gemma/[Gemma_2]Using_with_mistral_rs.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "BojqmYOsPk0A" }, "source": [ "##### Copyright 2024 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "FtMmJ-pvPfNl" }, "outputs": [], "source": [ "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "IyATsAVlPz1W" }, "source": [ "# Inference with Gemma 2 using mistral.rs\n", "\n", "[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open-source language models from Google. Built from the same research and technology used to create the Gemini models, Gemma models are text-to-text, decoder-only large language models (LLMs), available in English, with open weights, pre-trained variants, and instruction-tuned variants.\n", "Gemma models are well-suited for various text-generation tasks, including question-answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n", "\n", "[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a versatile framework for Large Language Model (LLM) inference supporting text-to-text and multimodal LLMs. It offers features like grammar support for structured outputs, and inference using LoRA-fine-tuned models, making it a powerful tool for a wide range of AI applications.\n", "\n", "In this notebook, you will learn how to prompt the Gemma 2 model in various ways using the **mistral.rs** Python APIs in a Google Colab environment.\n", "<table align=\"left\">\n", " <td>\n", " <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Using_with_mistral_rs.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", " </td>\n", "</table>" ] }, { "cell_type": "markdown", "metadata": { "id": "nFgaq_--Qg-O" }, "source": [ "## Setup\n", "\n", "### Select the Colab runtime\n", "To complete this tutorial, you can use the CPU runtime in Colab. Since the default accelerator for any Colab runtime is CPU, simply click the **Connect** button in the top-right corner of the Colab window.\n", "\n", "### Setup Hugging Face\n", "\n", "**Before you dive into the tutorial, let's get you set up with Hugging face:**\n", "\n", "1. **Hugging Face Account:** If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n", "\n", "2. **Hugging Face Token:** Generate a Hugging Face access (preferably `write` permission) token by clicking [here](https://huggingface.co/settings/tokens). 
You'll need this token later in the tutorial.\n", "\n", "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**" ] }, { "cell_type": "markdown", "metadata": { "id": "bLEUJYZ8QmGz" }, "source": [ "### Configure your HF token\n", "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n", "\n", "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n", "2. Create a new secret with the name `HF_TOKEN`.\n", "3. Copy/paste your HF token key into the Value input box of `HF_TOKEN`.\n", "4. Toggle the button on the left to allow notebook access to the secret." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hK-qUiQGQbe5" }, "outputs": [], "source": [ "import os\n", "from google.colab import userdata\n", "\n", "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", "# vars as appropriate for your system.\n", "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")" ] }, { "cell_type": "markdown", "metadata": { "id": "5DrB-trDQsgE" }, "source": [ "### Install dependencies\n", "\n", "First, you must install the Python package for mistral.rs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LLjxxhk2Qrf_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting mistralrs==0.3.2\n", " Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl.metadata (1.7 kB)\n", "Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl (14.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.4/14.4 MB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: mistralrs\n", "Successfully installed mistralrs-0.3.2\n" ] } ], "source": [ "!pip install mistralrs==0.3.2" ] }, { "cell_type": "markdown", "metadata": { "id": "_wsT7Eb_1esZ" }, "source": [ "## Overview\n", "\n", " In this notebook, you will explore how to prompt the Gemma 2 model using the Python APIs of mistral.rs. It's divided into the following sections:\n", "\n", "1. Load the Gemma 2 model\n", "2. Response generation\n", "3. Non-streaming chat completion\n", "4. Streaming chat completion\n", "5. Inference with grammar\n", "\n", "Additionally, you can run inference on Gemma 2 using the Rust APIs and command-line interface of mistral.rs. Read more about them in the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#get-started-fast-).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "tMfQBPRYsknZ" }, "source": [ "## Load the Gemma 2 model\n", "\n", "In this section, you will learn how to load the Gemma 2 model from Hugging Face Hub using the mistral.rs Python APIs.\n", "\n", "First, create an instance of the `Which` class specifying the model to load. Set `model_id` to the Hugging Face Hub repository ID of the desired Gemma 2 model variant. Specify `Architecture.Gemma2` for the arch parameter.\n", "\n", "Create an instance of the `Runner` class by providing the created `Which` instance. Set `token_source` to `env:HF_TOKEN` to indicate downloading the model from Hugging Face Hub using your existing Hugging Face token.\n", "\n", "Models can also be loaded from the local file system. 
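For example, if you have already downloaded the model files, you could point `model_id` at a local directory instead of a Hub repository ID. The snippet below is a minimal sketch, assuming mistral.rs accepts a local path wherever a repository ID is expected and that `./gemma-2-2b-it` is a hypothetical folder holding the config, tokenizer, and weight files:

```python
from mistralrs import Runner, Which, Architecture

# Hypothetical local copy of the model (config, tokenizer, and safetensors
# weights). When nothing needs to be fetched from the Hub, a Hugging Face
# token is typically not required.
local_runner = Runner(
    which=Which.Plain(
        model_id="./gemma-2-2b-it",
        arch=Architecture.Gemma2,
    ),
)
```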
Refer to the [mistral.rs get models guide](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#getting-models) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VYFDjCBBtBhS" }, "outputs": [], "source": [ "from mistralrs import Runner, Which, Architecture\n", "\n", "runner = Runner(\n", " which=Which.Plain(\n", " model_id=\"google/gemma-2-2b-it\",\n", " arch=Architecture.Gemma2,\n", " ),\n", " token_source=\"env:HF_TOKEN\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "YoXapKI3vE4m" }, "source": [ "## Response generation\n", "\n", "Next, you'll send a prompt to Gemma 2 for inference using mistral.rs.\n", "\n", "To achieve this, create an instance of the `CompletionRequest` class, specifying your desired prompt in the `prompt` parameter. Set the model parameter to \"gemma2\". mistral.rs allows you to configure common LLM parameters like `top_p`, `max_tokens`, and `temperature` within `CompletionRequest`. You can find the full list in the [class definition of `CompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L44).\n", "\n", "Finally, call the `send_completion_request` method on the `runner` object, passing the newly created `CompletionRequest` instance as the `request` argument." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "niIzVlPbeSvO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A black hole is a region of spacetime where gravity is so strong that nothing, not even light, can escape. \n", "\n", "Here's a breakdown:\n", "\n", "**What causes a black hole?**\n", "\n", "* **Stellar Collapse:** When a massive star runs out of fuel, it collapses under its own gravity. This collapse can create a black hole if the star's mass is enough.\n", "* **Supermassive Black Holes:** These are found at the centers of most galaxies, including our own Milky Way. Their formation is still a mystery, but they likely formed from the merging of smaller black holes or from the collapse of massive gas clouds.\n", "\n", "**Key Features of a Black Hole:**\n", "\n", "* **Event Horizon:** This is the boundary around a black hole beyond which escape is impossible. 
Once something crosses the event horizon, it's trapped forever.\n", "* **Singularity:** This is the theoretical point at the center of a black hole where all the matter is compressed into an infinitely small point.\n", "* **Gravitational Pull:** Black holes have immense gravitational pull, which warps spacetime around them.\n", "\n", "**How do we know black holes exist?**\n", "\n", "* **Gravitational Effects:** We can detect black holes by observing their gravitational effects on nearby stars and gas.\n", "* **X-ray Emissions:** As matter falls into a black hole, it heats up and emits X-rays, which can be detected by telescopes.\n", "* **Gravitational Waves:** When black holes collide, they create ripples in spacetime called gravitational waves, which can be detected by specialized instruments.\n", "\n", "**Interesting Facts:**\n", "\n", "* Black holes are not actually \"holes\" in space, but rather regions of extreme density and gravity.\n", "* The size of a black hole is measured by its \"Schwarzschild radius,\" which is the distance from the center at which the escape velocity equals the speed of light.\n", "* Black holes are still a subject of active research, and scientists are constantly learning more about them.\n", "\n", "\n", "Let me know if you have any other questions! \n", "\n", "Usage {\n", " completion_tokens: 415,\n", " prompt_tokens: 7,\n", " total_tokens: 422,\n", " avg_tok_per_sec: 1.3271735,\n", " avg_prompt_tok_per_sec: 1.4326648,\n", " avg_compl_tok_per_sec: 1.3255271,\n", " total_time_sec: 317.969,\n", " total_prompt_time_sec: 4.886,\n", " total_completion_time_sec: 313.083,\n", "}\n" ] } ], "source": [ "from mistralrs import CompletionRequest\n", "\n", "request = CompletionRequest(\n", " model=\"gemma2\",\n", " prompt=\"What is a black hole?\",\n", " max_tokens=512,\n", " top_p=0.1,\n", " temperature=0.1,\n", ")\n", "\n", "response = runner.send_completion_request(request)\n", "\n", "print(response.choices[0].text)\n", "print(response.usage)" ] }, { "cell_type": "markdown", "metadata": { "id": "vu1jh_nLMXyS" }, "source": [ "## Non-streaming chat completion\n", "\n", "In addition to basic prompting, mistral.rs also supports multi-turn conversations with LLMs.\n", "\n", "For multi-turn conversations, you'll maintain a list of dictionaries, each representing a single message in the conversation history with Gemma 2. Each dictionary should include a `role` key specifying whether the user or the assistant (model) generated the message, and a `content` key holding the message text.\n", "\n", "Create an instance of `ChatCompletionRequest` by passing the conversation history to the `messages` parameter. Then, invoke the `send_chat_completion_request` method of the `runner` with the newly created `ChatCompletionRequest` instance as the `request` argument.\n", "\n", "You can also configure any other model parameters you want to use. Refer to the [class definition of `ChatCompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L11)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ffjhPyHpvEJM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. 
\n", "\n", "Usage {\n", " completion_tokens: 29,\n", " prompt_tokens: 74,\n", " total_tokens: 103,\n", " avg_tok_per_sec: 2.808376,\n", " avg_prompt_tok_per_sec: 4.5204644,\n", " avg_compl_tok_per_sec: 1.4281493,\n", " total_time_sec: 36.676,\n", " total_prompt_time_sec: 16.37,\n", " total_completion_time_sec: 20.306,\n", "}\n" ] } ], "source": [ "from mistralrs import ChatCompletionRequest\n", "\n", "messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"Who are the first humans to land on moon?\",\n", " },\n", " {\n", " \"role\": \"assistant\",\n", " \"content\": \"The first humans to land on the Moon were **Neil Armstrong and Buzz Aldrin** of the Apollo 11 mission on **July 20, 1969**.\",\n", " },\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"Which country did they belong to?\",\n", " },\n", "]\n", "\n", "request = ChatCompletionRequest(\n", " model=\"gemma2\",\n", " messages=messages,\n", " max_tokens=256,\n", " presence_penalty=1.0,\n", " top_p=0.1,\n", " temperature=0.1,\n", ")\n", "\n", "response = runner.send_chat_completion_request(request)\n", "\n", "print(response.choices[0].message.content)\n", "print(response.usage)" ] }, { "cell_type": "markdown", "metadata": { "id": "7W-F2n_xhWRl" }, "source": [ "Notice how the model generated the response for the last query from the user based on the previous conversation history." ] }, { "cell_type": "markdown", "metadata": { "id": "qZehF79UOVWg" }, "source": [ "## Streaming chat completion\n", "\n", "mistral.rs also supports obtaining streamed responses from the model during multi-turn conversations. To enable streaming, set the `stream` parameter of `ChatCompletionRequest` to `True`.\n", "\n", "The `send_chat_completion_request` function will return an iterable object. You can access each chunk of the streamed response by iterating over this object.\n", "\n", "Pass the conversation history (`messages`) you created earlier to the `ChatCompletionRequest` instance with `stream=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nxNg6cXFog29" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. 
\n" ] } ], "source": [ "request = ChatCompletionRequest(\n", " model=\"gemma2\",\n", " messages=messages,\n", " max_tokens=256,\n", " presence_penalty=1.0,\n", " top_p=0.1,\n", " temperature=0.1,\n", " stream=True,\n", ")\n", "\n", "response = runner.send_chat_completion_request(request)\n", "\n", "for chunk in response:\n", " print(chunk.choices[0].delta.content, end=\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "3RALnmwQP6WQ" }, "source": [ "## Inference with grammar\n", "\n", "Specifying a grammar when creating a `ChatCompletionRequest` or `CompletionRequest` allows you to constrain the model's output to a specific format.\n", "\n", "To ensure the model generates responses that match a regular expression, set the parameter `grammar_type` to \"regex\" and `grammar` to the desired regular expression string.\n", "\n", "The following code snippet illustrates how to restrict the model's response to two-digit numbers using regular expression grammar:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gi2FGoXJv-K9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "13\n", "Usage {\n", " completion_tokens: 5,\n", " prompt_tokens: 31,\n", " total_tokens: 36,\n", " avg_tok_per_sec: 3.115804,\n", " avg_prompt_tok_per_sec: 3.5123498,\n", " avg_compl_tok_per_sec: 1.8328446,\n", " total_time_sec: 11.554,\n", " total_prompt_time_sec: 8.826,\n", " total_completion_time_sec: 2.728,\n", "}\n" ] } ], "source": [ "request = CompletionRequest(\n", " model=\"gemma2\",\n", " prompt=\"What is the next number in the fibonnaci series: 1, 1, 2, 3, 5, 8 ?\",\n", " grammar_type=\"regex\",\n", " grammar=r\"\\d{2}\",\n", " top_p=0.1,\n", " temperature=0.1,\n", ")\n", "\n", "response = runner.send_completion_request(request)\n", "\n", "print(response.choices[0].text)\n", "print(response.usage)" ] }, { "cell_type": "markdown", "metadata": { "id": "X72NYX92nv40" }, "source": [ "mistral.rs offers support for various grammar types beyond regular expressions. Refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#description) for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "vlRAC8OFa1pQ" }, "source": [ "These are just a few examples of how you can perform inference with Gemma 2 using mistral.rs. To explore its capabilities further, you can refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#--mistralrs).\n" ] } ], "metadata": { "colab": { "name": "[Gemma_2]Using_with_mistral_rs.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }